# -*- coding: utf-8 -*-
"""
Created on Wed May 18 11:34:57 2022

@author: LSY
"""

from Models import PCCPlant
import casadi
import numpy as np
import matplotlib.pyplot as plt
from numpy import random
import timeit
from os.path import dirname, join as pjoin
import mpctools as mpc
import scipy.io as spio

from Sens_proj_cc import SenMatrix, SensAnal

np.set_printoptions(precision=4, linewidth=400)

# ---------------------------------------------------------------------
# Simulation parameters
# ---------------------------------------------------------------------
random.seed(927)
rank_xp = True
# Plant simulation step (= sampling time) in seconds.  It must equal the
# sampling time that was used when the data were generated.
dt = 10 * 60.0
# NOTE(review): the original comments tagged tf and Nt as "over-ridden";
# nothing in this section overrides them -- confirm whether that is stale.
tf = 10 * 3600  # Total simulation time, s
Nt = int(round(tf / dt))  # Number of simulation points
tplot = 10 * np.arange(Nt)  # sample index * 10 (presumably minutes, for plotting)

# ---------------------------------------------------------------------
# Steady state and initial values of the plant
# ---------------------------------------------------------------------
xs = np.load('xss.npy')
zs = np.load('zss.npy')
# The stored steady-state input is extended by one entry (1.0) for the
# additional input.
us = np.append(np.load('uss.npy'), 1.0)
x0 = np.load('x0.npy')  # 103
z0 = np.load('z0.npy')  # 7

# ---------------------------------------------------------------------
# Model parameters
# ---------------------------------------------------------------------
Nx = 103  # Number of differential states
Nz = 7  # Number of algebraic states
Nd = 2  # Number of disturbances
Nu = 2 + 1  # Number of inputs: 2 inputs + 1 disturbance (flue gas flow rate [0.8 1.2])
Ny = 23  # Number of outputs
Nw = Nx  # Process-noise dimension
Nv = Ny  # Measurement-noise dimension

# Trajectory containers; they start as lists and are appended to during
# the plant simulation.
tt = []
Y = []
X = []
Z = []
U = []

Nsim = 9  # Number of samples handed to the sensitivity analysis
sigma_w = 0  # Process-noise standard deviation
sigma_v = 0.001  # Measurement-noise standard deviation
# The generator is re-seeded between the two draws, so v does not depend
# on how many numbers were consumed for w.
w = sigma_w * random.randn(Nt, Nw)
random.seed(927)
v = sigma_v * random.randn(Nt, Nv)

# Min/range scaling of the initial state: x_scaled = (x - min_x) / delta_x.
# Entries with a (numerically) zero range get delta_x = 1 and min_x = 0,
# which leaves them unscaled and avoids a division by zero.
# Assumes x0, min_x and delta_x are 1-D with at least Nx entries.
delta_x = np.loadtxt('delta.txt')
min_x = np.loadtxt('min.txt')
flat_idx = np.flatnonzero(delta_x <= 1e-10)
delta_x[flat_idx] = 1
min_x[flat_idx] = 0
x0_scale = (x0[:Nx] - min_x[:Nx]) / delta_x[:Nx]
print(x0)

# The scaled initial state is the first entry of the state trajectory.
X.append(x0_scale)

# The first two inputs stay at their stored steady-state values; the third
# one is pinned to 1.0.
us[2] = 1.0

# Pre-allocated (Nt x dimension) buffers for the simulated trajectories.
usim = np.zeros((Nt, Nu))
xsim = np.zeros((Nt, Nx))
zsim = np.zeros((Nt, Nz))
ysim = np.zeros((Nt, Ny))

#### sensor set(measurements)
def measure(x):
    """Return the measured outputs ``y = C x`` for the state vector ``x``.

    ``C`` is a ``Ny x Nx`` selection matrix (``Ny`` and ``Nx`` are read from
    module level at call time) that copies 23 state entries into the output
    vector.  The row/column indices are hard-coded, so ``Ny >= 23`` and
    ``Nx >= 103`` are required.

    Parameters
    ----------
    x : array-like with ``Nx`` entries
        State vector.  The function is also handed to
        ``mpc.getCasadiFunc`` elsewhere in this file, so ``x`` may be a
        symbolic object instead of a NumPy array.

    Returns
    -------
    The product ``np.dot(C, x)`` holding the ``Ny`` outputs.
    """
    c_extend = np.zeros([Ny, Nx])
    c_extend[0:5, 30:35] = np.eye(5)      # C_G,CO2 (absorption)
    c_extend[5:10, 45:50] = np.eye(5)     # T_G
    c_extend[10:15, 80:85] = np.eye(5)    # C_G,CO2 (desorption)
    c_extend[15:20, 95:100] = np.eye(5)   # T_G
    c_extend[20:22, 100:102] = np.eye(2)  # T (Heat exchanger)
    # Single element: assign a plain scalar.  Assigning np.eye(1) here relied
    # on converting a 2-D array to a scalar, which NumPy has deprecated.
    c_extend[22, 102] = 1.0               # T (Reboiler)
    y = np.dot(c_extend, x)
    return y

# ---------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------
# Instantiate the PCC plant model.
pccplant = PCCPlant.PCCPlant()

# Symbolic variables of the DAE system.
t = casadi.SX.sym('t')  # Time
xd = casadi.SX.sym('x', Nx)  # Differential states
xz = casadi.SX.sym('z', Nz)  # Algebraic states
u = casadi.SX.sym('u', Nu)  # Control inputs
d = casadi.SX.sym('d', Nd)  # Disturbances
xddot = casadi.SX.sym('xdot', Nx)  # xdot

# IDAS integrator that advances the plant DAE over one interval of
# length dt; the inputs enter as the parameter vector 'p'.
ode = pccplant.getODE(xd, xz, u)
alg = pccplant.getALG(xd, xz)
dae = {'x': xd, 'z': xz, 'p': u, 'ode': ode, 'alg': alg}
opts = {"tf": dt, "abstol": 1e-5}  # interval length and absolute tolerance
plant_sim = casadi.integrator('I', 'idas', dae, opts)

# Input bounds.  Only u_min is applied below; u_max belonged to a disabled
# variant that perturbed the inputs randomly and clipped them to
# [u_min, u_max].
u_max = np.array([0.5812*1.2, 0.1357*1.1, 1*1.1])
u_min = np.array([0.5812*0.84, 0.1357*0.69, 1*0.4])

# Open-loop simulation: the input is held at u_min for every step and each
# step starts from the previous end state.  z0 is passed unchanged as the
# initial guess for the algebraic states at every step.
for i in range(Nt):
    tt.append(i * dt / 60)  # Current sample time (dt is in s, so minutes)
    U.append(u_min)

    sol = plant_sim(x0=X[-1], z0=z0, p=U[-1])

    X.append(sol["xf"].full()[:, 0])
    Z.append(sol["zf"].full()[:, 0])

# X holds Nt + 1 rows (initial state plus one per step); Z and U hold Nt.
X = np.array(X)
Z = np.array(Z)
U = np.array(U)

# Fill the pre-allocated buffers with the first Nt samples; the process
# noise w is added to the states.
xsim[:, :] = X[:Nt, :] + w
zsim[:, :] = Z[:Nt, :]
usim[:, :] = U[:Nt, :]

# NOTE(review): the original comment read "Sampling time 15min = dt", but
# Delta is 60 s while dt is 600 s -- confirm which value is intended.
Delta = 60.0

# Undo the min/range scaling to obtain the states in physical units.
x_actual = xsim * delta_x[:Nx] + min_x[:Nx]

#xs = np.load('xss.npy')
#min_x = 0.5*xs
#delta_x = xs

#xsim_scale = np.zeros((Nt,Nx))
#for i in range(Nt):
#    for j in range(Nx):
#        xsim_scale[i,j] = (xsim[i,j]-min_x[j])/delta_x[j]


# CasADi function wrappers built with mpctools:
#   F      - the plant ODE right-hand side ("getODE" of pccplant)
#   F_rk4  - the same model requested with rk4=True and step Delta
#   H      - the measurement map defined by measure()
F = mpc.getCasadiFunc(pccplant.getODE,[Nx,Nz,Nu],["x","z","u"],"ccmodel")
F_rk4 = mpc.getCasadiFunc(pccplant.getODE,[Nx,Nz,Nu],["x","z","u"],"ccmodel", rk4=True, Delta=Delta)
H = mpc.getCasadiFunc(measure,[Nx],["x"],"measurement")


print("Calculate the sensitive matrix with all measurements ......")
Tstart = timeit.default_timer()


# Measurements of the simulated (scaled) states with measurement noise v
# added.  The loop index is deliberately not called `t`, so the CasADi
# time symbol `t` is not overwritten.
for step in range(Nt):
    ysim[step,:] = measure(X[step,:]) + v[step,:]


# Sensitivity matrix over the first Nsim samples.
# Presumably one block of Ny rows per sample (the selection code further
# down indexes it that way) -- confirm against Sens_proj_cc.SenMatrix.
sm_all = SenMatrix(F,H,xsim[0:Nsim,:],ysim[0:Nsim,:],usim[0:Nsim-1,:],zsim[0:Nsim-1,:],Nsim-1)

# State columns that are removed from the sensitivity matrix.
# NOTE(review): presumably states without any sensitivity -- confirm.
de_zeros = [0,1,2,3,4,50,51,52,53,54,75,76,77,78,79]
sm_all = np.delete(sm_all,de_zeros,axis=1)


# Same measurements computed from the states in physical units, again with
# the noise v added.
y_act = np.zeros((Nt,Ny))
for step in range(Nt):
    y_act[step,:] = measure(x_actual[step,:]) + v[step,:]


# Price of each of the Ny sensors, in the order of measure()'s outputs.
price = np.array([20, 20, 20, 20, 20, 1, 1, 1, 1, 1, 20, 20, 20, 20, 20, 1, 1, 1, 1, 1, 1, 1, 1]) #right

Tend = timeit.default_timer()
print("The sensitive matrix sm = \n",sm_all)
print('Cost time:', str(Tend-Tstart),'s')

# From here on Nx counts only the state columns kept in sm_all.  measure()
# reads the module-level Nx and uses hard-coded column 102, so it must not
# be called after this line.
Nx = Nx - len(de_zeros)
#### The sensitive degree with all the measurments:
Tstart = timeit.default_timer()
Flag = SensAnal(sm_all,Nx,Ny,sigma_w,sigma_v,Nsim-1,rank_xp)
Rank_Sen = Flag[0]  # first entry: rank of the sensitivity matrix
print('Rank of Sensiti (total):', Rank_Sen)
Sen_deg = Flag[-1]  # last entry: sensitivity degree
print("Sen degree (total):",Sen_deg)
Tend = timeit.default_timer()
print('Cost time:', str(Tend-Tstart),'s')
########################################################


############ Measurements that are already selected before the search starts:
# State of the greedy search before the first step: no sensor is selected.
# (A pre-selected start set can be used instead, e.g.
#  M_selected = np.array([0,5,6,7,8,9,15,16,18,19]).)
N_known = 0  # The number of selected sensors
M_selected = np.empty((0, 0))

# Total price of the current selection: z_curr is a 0/1 indicator row over
# the Ny sensors (all zeros here), so the cost starts at zero.
z_curr = np.zeros((1, Ny))
cost = z_curr.dot(price)

print("These measurements have been selected:\n", M_selected)
print("cost:", cost)

                
#t_list = []      
#for ni in range(Nsim):        
#    for nj in M_selected:
#        t_list.append(int(nj+ni*(23)))     
#SM_I = sm_all[t_list,:]
#Flag = SensAnal(SM_I,Nx,N_known,sigma_w,sigma_v,Nsim-1,rank_xp)
#Rank_Sen_add = Flag[0]
#print('Rank of Sensiti:', Rank_Sen_add)
#Sen_state_sort = Flag[1:-1]
#print("Sen state sort:",Sen_state_sort)
#Sen_deg = Flag[-1]
#print("Sen degree:",Sen_deg)

################################ begin select ########################################################
print("------------------------------------------Begin Select-------------------------------------")

# Greedy forward selection.  In every outer step each not-yet-selected
# sensor is tried together with the current selection; among the candidate
# sets whose total price stays within the budget, the sensor with the
# largest CPI (sensitivity degree divided by cost) is added.  The search
# stops when no affordable candidate is left.
cost_budget = 113  # Largest admissible total price of a sensor set

# Statistics of the most recently accepted sensor.  They are kept as (1, 1)
# arrays so the summary printed at the end keeps its format.
Rank_Sen_final = np.zeros((1,1))
Sen_deg_final = np.zeros((1,1))
cost_final = np.zeros((1,1))
CPI_final = np.zeros((1,1))


for k in range(Nx):
    print("-----------------------------------------------------------------------------------")
    print(N_known+1,"- th add")
    # Best candidate found in this step.
    Rank_Sen_selected_max = np.zeros((1,1))
    Sen_deg_opt = np.zeros((1,1))
    cost_opt = np.zeros((1,1))
    CPI_max = np.zeros((1,1))
    N_known = N_known + 1
    M_opt_add = np.zeros((1,1))

    flagy = 0           # 1 once at least one affordable candidate was evaluated
    best_found = False  # True once a candidate was recorded as the best one
    for i in range(Ny):
        if i in M_selected:
            continue
        if len(M_selected):
            Madd_try = np.concatenate((M_selected,i),axis = None)
        else:
            Madd_try = np.array([i])
        Madd_try = np.sort(Madd_try)
        print('try these mesaurements',Madd_try)

        # Total price of the candidate set.
        z_curr = np.zeros((1,Ny))
        for nj in Madd_try:
            z_curr[:,int(nj)] = 1
        cost = np.dot(z_curr,price)
        print("cost:", cost)
        if cost > cost_budget:
            continue

        flagy = 1

        # Rows of sm_all that belong to the candidate sensors: sensor nj at
        # sample ni sits in row nj + ni * Ny.
        t_list = [int(nj + ni * Ny) for ni in range(Nsim) for nj in Madd_try]
        SM_I = sm_all[t_list,:]

        Flag = SensAnal(SM_I,Nx,N_known,sigma_w,sigma_v,Nsim-1,rank_xp)

        Rank_Sen_add = Flag[0]
        print('Rank of Sensiti:', Rank_Sen_add)
        Sen_deg = Flag[-1]

        CPI = Sen_deg/cost
        print("Sen degree:",Sen_deg)
        print("CPI:", CPI)

        # ">=" keeps the last of several equally good candidates.
        if CPI >= CPI_max:
            Sen_deg_opt[:] = Sen_deg
            cost_opt[:] = cost
            CPI_max[:] = CPI
            Rank_Sen_selected_max[0,0] = Rank_Sen_add
            M_opt_add[:] = i
            best_found = True

    # Stop when nothing affordable is left.  Also stop when no candidate
    # could be ranked (e.g. every CPI was NaN); otherwise the placeholder
    # index 0 in M_opt_add would be appended as if it had been selected.
    if flagy == 0 or not best_found:
        break

    # Copy element-wise: assigning a whole (1, 1) array to a single element
    # relies on an array-to-scalar conversion that NumPy has deprecated.
    Rank_Sen_final[0,0] = Rank_Sen_selected_max[0,0]
    Sen_deg_final[0,0] = Sen_deg_opt[0,0]
    cost_final[0,0] = cost_opt[0,0]
    CPI_final[0,0] = CPI_max[0,0]

    print("new  added:",M_opt_add)
    M_selected = np.concatenate(((M_selected,M_opt_add)), axis = None)
    print("selected measurements:\n", M_selected)


print("Finially selected:",M_selected)
print("Rank:",Rank_Sen_final)
print("Degree:",Sen_deg_final)
print("cost:",cost_final)
print("CPI:",CPI_final)
